import os
import time
import datetime
import json

from tqdm import tqdm
import argparse
from argparse import ArgumentParser
import numpy as np
from rich import print as rprint

# Process-wide environment tweaks; they are placed before the project/framework
# imports further down so that those imports see them.
# TF_CPP_MIN_LOG_LEVEL: presumably quiets TensorFlow's native INFO/WARNING logs.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  
# NOTE(review): CUDA_VISIBLE_DEVICES is set to "-1" and then removed again on the
# very next line, so the net effect is that the variable is unset. Confirm which
# of the two (hide GPUs vs. leave device visibility untouched) is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 
os.environ.pop("CUDA_VISIBLE_DEVICES", None)

from distutils.util import strtobool

def bool_argument(value):
    """Convert a string value to boolean.

    Accepts the same tokens as ``distutils.util.strtobool`` (case-insensitive)
    without calling into ``distutils``, which is deprecated and no longer
    shipped with Python 3.12+.

    Args:
        value: String such as ``"yes"``, ``"True"``, ``"0"`` or ``"off"``.

    Returns:
        ``True`` for y/yes/t/true/on/1, ``False`` for n/no/f/false/off/0.

    Raises:
        ValueError: If ``value`` is not one of the recognised tokens.
    """
    truthy = ('y', 'yes', 't', 'true', 'on', '1')
    falsy = ('n', 'no', 'f', 'false', 'off', '0')

    normalized = value.lower()
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise ValueError("invalid truth value %r" % (normalized,))

# Look up the installed overcooked_ai distribution version at import time.
# main() compares VERSION against '1.1.0' and '0.0.1' to decide whether layout
# names come from NEW_LAYOUTS or OLD_LAYOUTS.
import importlib_metadata
VERSION = importlib_metadata.version("overcooked_ai")
print(f'\n----This overcook version is {VERSION}----\n')

from overcooked_ai_py.mdp.overcooked_mdp import OvercookedGridworld
from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv
from overcooked_ai_py.agents.agent import AgentGroup
from overcooked_ai_py.mdp.actions import Action


from causal_graph.SCMOp_Predict_Next_Action_Only_2pot import SCMOp_Predict_Next_Action_Only_2pot
from causal_graph.SCMOp_Predict_Next_Action_Only import SCMOp_Predict_Next_Action_Only



from utils import NEW_LAYOUTS, OLD_LAYOUTS, make_agent, set_seed, save_result, save_state_pkl, read_state_pkl, read_action_pkl, save_action_pkl
from utils_causal import convert_state_to_input, convert_action_to_input, write_buffer_to_folder, load_buffer, save_causal_graph, load_causal_graph, load_causal_graph_DAG 


# Suppress all warnings
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*cuBLAS factory.*") # ignore "Unable to register cuBLAS factory" due to use tf-CPU
# The unqualified filter below ignores every warning category, so the three
# targeted filters above are already covered by it.
warnings.filterwarnings("ignore")


import torch
# Print tensors (e.g. the edge-probability matrix logged during SCM training)
# in fixed-point notation with 4 decimals and wide lines.
torch.set_printoptions(sci_mode=False, precision=4, linewidth=200)


# Variable names of the single-pot state/action vector; the position in this
# list is the index of the variable in the vector.
OvercookedState = [
    # --- what the player is holding ---
    "empty_hand",                # 0: nothing held
    "hold_onion",                # 1: holding an onion
    "hold_dish",                 # 2: holding an empty dish
    "dish_with_soup",            # 3: holding a dish filled with soup
    # --- pot status ---
    "pot_0",                     # 4: pot contains 0 onions
    "pot_1",                     # 5: pot contains 1 onion
    "pot_2",                     # 6: pot contains 2 onions
    "pot_3",                     # 7: pot contains 3 onions
    "pot_finished",              # 8: soup in the pot is cooked
    # --- goal flag ---
    "goal_delivered",            # 9: a soup was delivered
    # --- action slots ---
    "pickup(onion)",             # 10
    "put_onion_in_pot()",        # 11
    "pickup(dish)",              # 12
    "fill_dish_with_soup()",     # 13
    "deliver_soup()",            # 14
    "place_onion_on_counter()",  # 15
    "place_dish_on_counter()",   # 16
]

# Variable names of the two-pot state/action vector; the position in this list
# is the index of the variable in the vector.
OvercookedState_2_pot = [
    # --- what the player is holding ---
    "empty_hand",                # 0: nothing held
    "hold_onion",                # 1: holding an onion
    "hold_dish",                 # 2: holding an empty dish
    "dish_with_soup",            # 3: holding a dish filled with soup
    # --- first pot ---
    "pot_0",                     # 4: pot contains 0 onions
    "pot_1",                     # 5: pot contains 1 onion
    "pot_2",                     # 6: pot contains 2 onions
    "pot_3",                     # 7: pot contains 3 onions
    "pot_finished",              # 8: soup in the pot is cooked
    # --- second pot ---
    "pot_1_0",                   # 9: second pot contains 0 onions
    "pot_1_1",                   # 10: second pot contains 1 onion
    "pot_1_2",                   # 11: second pot contains 2 onions
    "pot_1_3",                   # 12: second pot contains 3 onions
    "pot_1_finished",            # 13: soup in the second pot is cooked
    # --- goal flag ---
    "goal_delivered",            # 14: a soup was delivered
    # --- action slots ---
    "pickup(onion)",             # 15
    "put_onion_in_pot()",        # 16
    "pickup(dish)",              # 17
    "fill_dish_with_soup()",     # 18
    "deliver_soup()",            # 19
    "place_onion_on_counter()",  # 20
    "place_dish_on_counter()",   # 21
]

# Medium-level action names that main() treats as "interact" actions when only
# actions are recorded. The number after each entry is the slot of the
# corresponding action variable in OvercookedState_2_pot (not an index into
# this list); note the spelling here has no parentheses.
interact_action_list = [
    "pickup_onion",            # slot 15
    "put_onion_in_pot",        # slot 16
    "pickup_dish",             # slot 17
    "fill_dish_with_soup",     # slot 18
    "deliver_soup",            # slot 19
    "place_onion_on_counter",  # slot 20
    "place_dish_on_counter",   # slot 21
]

def train_causality(data_buffer, path, train_step, layout):
    """
    Train the causal model on `data_buffer` and return the learned causal graph.

    Each outer iteration runs 10 `train_f()` steps followed by 10 `train_s()`
    steps. Every 100 iterations the current edge probabilities are printed and
    the best graph found so far is checkpointed to `path`; the final best graph
    is saved to `path` once more after training.

    Args:
        data_buffer: Training data handed to the SCM model constructor.
        path: Destination passed to `save_causal_graph`.
        train_step: Number of outer training iterations.
        layout: Layout name; "cramped_room" selects the single-pot model,
            every other value selects the two-pot model.

    Returns:
        `torch.sigmoid(causal_model.best_s_params)`, i.e. the graph that was
        written to `path`.
    """
    inner_steps = 10      # train_f / train_s calls per outer iteration
    checkpoint_every = 100  # outer iterations between log + checkpoint

    if layout == "cramped_room":
        model_cls = SCMOp_Predict_Next_Action_Only
    else:
        model_cls = SCMOp_Predict_Next_Action_Only_2pot

    # NOTE(review): v_num is len(OvercookedState) for both model classes, as in
    # the original code. Confirm whether the two-pot model should instead get
    # len(OvercookedState_2_pot).
    causal_model = model_cls(data_buffer, v_num=len(OvercookedState))

    for step in tqdm(range(train_step)):
        for _ in range(inner_steps):
            causal_model.train_f()

        for _ in range(inner_steps):
            causal_model.train_s()

        if step % checkpoint_every == 0:
            # Log the current edge probabilities, checkpoint the best ones.
            print(torch.sigmoid(causal_model.s_params.edge_params))
            save_causal_graph(torch.sigmoid(causal_model.best_s_params), path)

    # Extract, persist and hand back the learned causal graph.
    causal_graph = torch.sigmoid(causal_model.best_s_params)
    save_causal_graph(causal_graph, path)
    return causal_graph

    

def get_causal_graph(path, data_buffer, train=True, train_step=10000, layout=None):
    """
    Train a new causal graph or load an existing one.

    Args:
        path: File the graph is saved to (training) or loaded from.
        data_buffer: Training data; only used when `train` is true.
        train: When true, train via `train_causality` and then terminate the
            process; when false, load the graph stored at `path`.
        train_step: Number of training iterations for `train_causality`.
        layout: Layout name forwarded to training / loading.

    Returns:
        The graph returned by `load_causal_graph_DAG(path, layout)` when
        `train` is false. Does not return when `train` is true.

    Raises:
        SystemExit: Always raised after training completes, so a training run
            never continues into the evaluation episodes.
    """
    if train:
        train_causality(data_buffer, path, train_step, layout)
        # Equivalent to the interactive `exit()` helper, but does not depend on
        # the `site` module having installed that helper.
        raise SystemExit
    return load_causal_graph_DAG(path, layout)

def main(variant):
    """
    Run the experiment described by `variant` (the parsed CLI options as a dict).

    Steps:
      1. Seed, build the MDP/environment for `variant['layout']`.
      2. Start from an empty data buffer (`save_buffer` true) or load one from
         `variant['buffer_path']`.
      3. Obtain the causal graph via `get_causal_graph` (which terminates the
         process after training when `train_SCM` is true).
      4. For each episode, build both agents and, in 'exp' mode, roll out
         `horizon` steps: encode each player's state, step the environment,
         encode the finished medium-level actions, optionally append
         transitions to the buffer, and accumulate reward.
      5. Print and optionally save the per-episode and mean results.

    Raises:
        RuntimeError: If the installed overcooked_ai VERSION is neither
            '1.1.0' nor '0.0.1'.
        ValueError: If a ProAgent is requested without planner/explainer model.
    """
    set_seed(variant['seed'])

    layout = variant['layout']
    horizon = variant['horizon']
    episode = variant['episode']
    mode = variant['mode']
    save_buffer = variant['save_buffer']
    record_action_only = variant['record_action_only']

    if VERSION == '1.1.0':
        mdp = OvercookedGridworld.from_layout_name(NEW_LAYOUTS[layout])
    elif VERSION == '0.0.1':
        mdp = OvercookedGridworld.from_layout_name(OLD_LAYOUTS[layout])
    else:
        # Previously this fell through and failed later with a NameError on `mdp`.
        raise RuntimeError(
            f"Unsupported overcooked_ai version {VERSION!r}; expected '1.1.0' or '0.0.1'"
        )

    env = OvercookedEnv(mdp, horizon=horizon)
    env.reset()

    p0_algo = variant['p0']
    p1_algo = variant['p1']
    print(f"\n===P0 agent: {p0_algo} | P1 agent: {p1_algo}===\n")

    start_time = time.time()
    results = []

    if save_buffer:
        data_buffer = []
    else:
        data_buffer = load_buffer(path=variant['buffer_path'])
        print(f"The length of the training buffer is: {len(data_buffer)}")

    causal_graph = get_causal_graph(path=variant["causal_graph_path"],
                                    data_buffer=data_buffer,
                                    train=variant['train_SCM'],
                                    train_step=variant['train_SCM_step'],
                                    layout=layout)

    # Variable-name table for this layout, and the positions derived from it:
    # the goal flag (index 9 single-pot / 14 two-pot) and the action slots that
    # follow it (10..16 single-pot / 15..21 two-pot).
    state_names = OvercookedState if layout == "cramped_room" else OvercookedState_2_pot
    goal_idx = state_names.index('goal_delivered')
    action_slots = range(goal_idx + 1, len(state_names))

    for ep in range(episode):
        agents_list = []
        for alg in [p0_algo, p1_algo]:
            if alg == "ProAgent":
                if variant['gpt_planner_model'] is None or variant['gpt_explainer_model'] is None:
                    raise ValueError('you should choose a gpt model')

                print(f"\n----Use Planner: {variant['gpt_planner_model']}----\n")
                print(f"\n----Use Explainer: {variant['gpt_explainer_model']}----\n")

                agent = make_agent(alg, mdp, layout,
                                   planner_model=variant['gpt_planner_model'],
                                   explainer_model=variant['gpt_explainer_model'],
                                   prompt_level=variant['prompt_level'],
                                   belief_revision=variant['belief_revision'],
                                   use_causal_graph=variant['use_causal_graph'],
                                   use_failure_handled=variant['use_failure_handled'],
                                   causal_graph=causal_graph,
                                   retrival_method=variant['retrival_method'],
                                   K=variant['K'], tuning=variant['tunning'],
                                   gamma=variant['gamma'],
                                   debug=variant['debug'])
            elif alg == "BC":
                agent = make_agent(alg, mdp, layout, seed_id=ep)
            else:
                agent = make_agent(alg, mdp, layout)

            agents_list.append(agent)

        team = AgentGroup(*agents_list)
        team.reset()

        env.reset()
        r_total = 0
        old_state = None
        old_op_state = None

        # Last interact action seen for each player; only consulted when
        # record_action_only is set, and carried across timesteps.
        interact_act_p0 = None
        interact_act_p1 = None

        if mode == 'exp':
            for t in range(horizon):
                s_t = env.state
                print(f'\n>>>>>>>>>>>>>time: {t}<<<<<<<<<<<<<<<<<<<<<\n')
                state = convert_state_to_input(state_names, s_t.objects, s_t.players[0].held_object, layout)
                state_op = convert_state_to_input(state_names, s_t.objects, s_t.players[1].held_object, layout)

                a_t = team.joint_action(s_t)
                print(f"\n-----------Controller-----------\n")
                print(f"action: P0 {Action.to_char(a_t[0])} | P1 {Action.to_char(a_t[1])}")
                obs, reward, done, env_info = env.step(a_t)

                ml_actions = obs.ml_actions
                if not record_action_only:
                    # NOTE(review): unlike the branch below, the layout is not
                    # passed here; confirm convert_action_to_input's default
                    # is correct for two-pot layouts.
                    state = convert_action_to_input(state, ml_actions[0])
                    state_op = convert_action_to_input(state_op, ml_actions[1])
                else:
                    # Keep the most recent interact action until a new one finishes.
                    if ml_actions[0] in interact_action_list:
                        interact_act_p0 = ml_actions[0]
                    if ml_actions[1] in interact_action_list:
                        interact_act_p1 = ml_actions[1]
                    state = convert_action_to_input(state, interact_act_p0, layout)
                    state_op = convert_action_to_input(state_op, interact_act_p1, layout)

                # A step reward of exactly 20 is treated as "goal delivered".
                if reward == 20:
                    state[goal_idx] = 1
                    state_op[goal_idx] = 1

                # Transition layout: [own previous, own current, partner previous, reward].
                if variant['collect_data'] == 'p0':
                    timestep = [old_state, state, old_op_state, reward]
                else:
                    timestep = [old_op_state, state_op, old_state, reward]

                if old_state is not None and save_buffer:
                    if not record_action_only:
                        data_buffer.append(timestep)
                    else:
                        # Only keep transitions whose current state has an action slot set.
                        if timestep[1] and any(timestep[1][slot] == 1 for slot in action_slots):
                            data_buffer.append(timestep)

                    if t % 10 == 0:
                        write_buffer_to_folder(data_buffer, variant['buffer_path'])

                old_state = state
                old_op_state = state_op

                # The loop variable must not reuse the episode index: it used to
                # overwrite it and corrupt the "Episode x/y" line below.
                skills = "".join(
                    f"P{player_idx} finished <{ml_action}>. "
                    for player_idx, ml_action in enumerate(ml_actions)
                    if ml_action is not None
                )
                print(skills)

                r_total += reward
                rprint("[red]" + f'r: {reward} | total: {r_total}\n\n')

            if save_buffer:
                write_buffer_to_folder(data_buffer, variant['buffer_path'])

            ## finish one episode
            if p0_algo == "ProAgent" or p1_algo == "ProAgent":
                print(f"\n================\n")
                try:  # ProAgent id = 0
                    print(f"P1's real behavior: {team.agents[0].teammate_ml_actions_dict}")
                    print(f"The infered P1's intention: {team.agents[0].teammate_intentions_dict}")
                except AttributeError:  # agent 0 has no teammate records: ProAgent id = 1
                    print(f"P0's real behavior: {team.agents[1].teammate_ml_actions_dict}")
                    print(f"The infered P0's intention: {team.agents[1].teammate_intentions_dict}")
                print(f"\n================\n")

        elif mode == 'demo':
            pass

        print(f"Episode {ep+1}/{episode}: {r_total}\n====\n\n")
        results.append(r_total)

    end_time = time.time()
    print(f"Cost time : {end_time - start_time:.3f}s-----\n\n")

    result_dict = {
        "input": variant,
        "raw_results": results,
        "mean_result": int(np.mean(results)),
    }
    for (k, v) in result_dict.items():
        print(f'{k}: {v}')

    if variant['save']:
        save_result(variant, result_dict, p0_algo, p1_algo, episode, layout, horizon)


def str2bool(value):
    """Parse a command-line flag value into a bool (argparse `type=` callable).

    Booleans pass through unchanged. Strings are matched case-insensitively:
    'true'/'1'/'yes'/'y' give True and 'false'/'0'/'no'/'n' give False.

    Raises:
        argparse.ArgumentTypeError: For any other string.
    """
    if isinstance(value, bool):
        return value

    token_to_bool = {
        'true': True, '1': True, 'yes': True, 'y': True,
        'false': False, '0': False, 'no': False, 'n': False,
    }
    parsed = token_to_bool.get(value.lower())
    if parsed is None:
        raise argparse.ArgumentTypeError('Boolean value expected.')
    return parsed
    
if __name__ == '__main__':

    # Example invocations:
    #   python main.py --layout cramped_room --p0 Greedy --p1 Greedy --horizon 100
    #   python main.py --layout cramped_room --p0 ProAgent --p1 BC --horizon 400 -pl l2-ap
    parser = ArgumentParser(description='OvercookedAI Experiment')

    # Basic experiment options.
    parser.add_argument('--layout', '-l', type=str, default='cramped_room', choices=['cramped_room', 'asymmetric_advantages', 'coordination_ring', 'forced_coordination', 'counter_circuit'])
    parser.add_argument('--p0',  type=str, default='ProAgent', choices=['ProAgent', 'Greedy', 'COLE', 'FCP', 'MEP', 'PBT', 'SP', 'BC', 'Random', 'Stay', 'Human', "None"], help='Algorithm for P0 agent 0')
    parser.add_argument('--p1', type=str, default='FCP', choices=['ProAgent', 'Greedy', 'COLE', 'FCP', 'MEP', 'PBT', 'SP', 'BC', 'Random', 'Stay', 'Human', "None"], help='Algorithm for P1 agent 1')
    parser.add_argument('--horizon', type=int, default=400, help='Horizon steps in one game')
    parser.add_argument('--episode', type=int, default=1, help='Number of episodes')

    # Options that only matter when one of the agents is ProAgent.
    # (The help text of the two model options used to be a copy-paste of
    # "Number of episodes".)
    parser.add_argument('--gpt_planner_model', type=str, default="meta-llama/Llama-3.3-70B-Instruct", choices=["meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Llama-3.3-70B-Instruct", "google/gemma-1.1-7b-it", "Qwen/Qwen2.5-14B-Instruct-1M", "command-r", "command-r-plus"], help='Model used by the ProAgent planner')
    parser.add_argument('--gpt_explainer_model', type=str, default='command-r', choices=["meta-llama/Meta-Llama-3-8B-Instruct", "command-r", "command-r-plus"], help='Model used by the ProAgent explainer')
    parser.add_argument('--prompt_level', '-pl', type=str, default='l2-ap', choices=['l1-p', 'l2-ap', 'l3-aip'], help="'l1-p': make plans directly without CoT; 'l2-ap': plans with analysis; 'l3-aip': plans with analysis and intention.")
    parser.add_argument('--belief_revision', '-br', type=str2bool, default=False, help='whether we use belief_revision or not')
    parser.add_argument('--train_SCM', '-scm', type=str2bool, default=False, help='Train SCM or not')
    parser.add_argument('--train_SCM_step', type=int, default=50000, help='Number of SCM training steps')
    parser.add_argument('--use_causal_graph', '-cg', type=str2bool, default=True, help='whether to use causal graph or not?')
    parser.add_argument('--tunning', type=str2bool, default=False, help="The tuning parameter.")
    parser.add_argument('--gamma', type=float, default=0.5, help='gamma value')
    parser.add_argument('--use_failure_handled', '-fh', type=str2bool, default=True, help='whether to use causal graph as backup or not?')
    parser.add_argument('--record_action_only', '-ra', type=str2bool, default=False, help='whether to record movement action')

    parser.add_argument('--retrival_method', type=str, default="recent_k", choices=['recent_k', 'bert_topk'], help='Use similarity-based(BERT, CLIP) retrieval or retrieve recent K history in dialog.')
    parser.add_argument('--K', type=int, default=1, help="The number of dialogues you want to retrieve.")
    parser.add_argument('--seed', type=int, default=1, help="The seed number.")

    # Options for run mode and for saving/loading results and buffers.
    parser.add_argument('--mode', type=str, default='exp', choices=['exp', 'demo'], help='exp mode run step-by-step, demo mode run via traj')
    parser.add_argument('--collect_data', type=str, default='p1', choices=['p1', 'p0'], help='collect_data_for_p0_or_p1')

    parser.add_argument('--save', type=str2bool, default=True, help='Whether save the result')
    parser.add_argument('--save_buffer', type=str2bool, default=False, help='Whether save the data_buffer')
    parser.add_argument('--buffer_path', '-bp', type=str, default='output_file_with_goal_and_op_and_rew_200k_aa_true_p0.pt', help='buffer path')
    parser.add_argument('--causal_graph_path', '-cgp', type=str, default="edge_params_with_action_after_200k_train_50k_test_w_op_final_cr_act_only_p0.pt", help='causal graph path')

    parser.add_argument('--log_dir', type=str, default=None, help='dir to save result')
    parser.add_argument('--debug', type=str2bool, default=True, help='debug mode')

    args = parser.parse_args()
    variant = vars(args)

    # Short layout tag used inside the causal-graph file names.
    layout_abbreviations = {
        "cramped_room": "cr",
        "asymmetric_advantages": "aa",
        "coordination_ring": "cor",
        "forced_coordination": "fc",
        "counter_circuit": "cc",
    }
    layout = layout_abbreviations[variant["layout"]]

    # NOTE(review): when either player is ProAgent the graph path is derived
    # from the layout and the ProAgent's seat, replacing any value given via
    # --causal_graph_path. Confirm that overriding an explicit path is intended.
    if variant["p0"] == "ProAgent":
        variant["causal_graph_path"] = f"edge_params_with_action_after_200k_train_50k_test_w_op_final_{layout}_act_only_p0.pt"
    elif variant["p1"] == "ProAgent":
        variant["causal_graph_path"] = f"edge_params_with_action_after_200k_train_50k_test_w_op_final_{layout}_act_only_p1.pt"

    print(variant["causal_graph_path"])
    print(variant)
    start_time = time.time()

    main(variant)
    end_time = time.time()
    print(f"\n=======Finshed all=========\n")
    print(f"Cost time : {end_time - start_time:.3f}s-----\n\n")
